import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import tensorstore as ts
from tqdm import tqdm
import pickle


def append_to_list_pickle(filename, new_items):
    """Append ``new_items`` to the list pickled in ``filename``, rewriting the file.

    The stored list is loaded entirely into memory, extended, and dumped back with
    the highest pickle protocol, so each call costs time proportional to the size
    of the stored list. A missing or zero-length file counts as an empty list.
    """
    try:
        with open(filename, 'rb') as src:
            stored = pickle.load(src)
    except (FileNotFoundError, EOFError):
        # Nothing saved yet (no file, or an empty one): start from scratch.
        stored = []

    stored.extend(new_items)

    with open(filename, 'wb') as dst:
        pickle.dump(stored, dst, protocol=pickle.HIGHEST_PROTOCOL)
        
# 1. Data Collection
class EmbeddingDataCollector:
    """Collect (hidden state, expected next-token embedding) pairs from a causal LM.

    ``collect_data`` runs the model over rows of token ids. For a random subset of
    positions in each row it records the last-layer hidden state together with the
    next-token-probability-weighted average of the input embedding matrix, centred
    and linearly transformed (see ``_whitening_transform``).
    """

    def __init__(self, model_name="google/gemma-7b", output_dir="~/dual-map/"):
        """Load ``model_name`` and its tokenizer and create ``output_dir``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            output_dir: Directory the collected pickles are written to. A leading
                ``~`` is expanded to the user's home directory.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Expand "~" explicitly: os.makedirs/open do not, and would otherwise
        # create a literal "~" directory under the current working directory.
        cache_dir = os.path.expanduser("~/gemma_cache")
        output_dir = os.path.expanduser(output_dir)

        # Load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # torch_dtype=torch.float16
            output_hidden_states=True,
            device_map='auto',
            cache_dir=cache_dir,
        ).eval()

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    @staticmethod
    def _save_pairs(path, data_pairs):
        """Atomically overwrite ``path`` with the pickled ``data_pairs`` list."""
        # Dump to a sibling temp file and rename it over the target, so a crash
        # mid-dump cannot corrupt the only copy of the collected data.
        tmp_path = path + ".tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(data_pairs, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)

    def _finish(self, final_filename, data_pairs):
        """Write the final pickle, report its size, and return its path."""
        self._save_pairs(final_filename, data_pairs)
        print(f"Collected {len(data_pairs)} data pairs")
        print(final_filename)
        return final_filename

    @staticmethod
    def _whitening_transform(embedding_matrix):
        """Return ``(g_mean, whitening_matrix, scale)`` derived from ``embedding_matrix``.

        ``g_mean`` is the mean over rows (dim 0). ``whitening_matrix`` is
        ``V diag(1 / sqrt(s + 1e-6)) V^T``, taken from the SVD of the centred
        matrix. ``scale`` is ``sqrt(rows / cols)``.
        """
        with torch.no_grad():
            g = embedding_matrix.to(torch.float32)
            g_mean = g.mean(dim=0)
            # U (rows x cols) is not needed. Discard it immediately instead of
            # keeping it, plus an autograd graph, alive for the whole collection.
            _, s, vt = torch.linalg.svd(g - g_mean, full_matrices=False)
            # NOTE(review): this scales by s**-1/2 of the *singular values*.
            # Whitening by the covariance (Cov^{-1/2}) would use 1/s instead.
            # Kept as-is so new pairs stay comparable with previously collected
            # ones — confirm the intended transform.
            whitening_matrix = vt.T @ torch.diag(1.0 / torch.sqrt(s + 1e-6)) @ vt
        scale = np.sqrt(g.shape[0] / g.shape[1])
        return g_mean, whitening_matrix, scale

    def collect_data(self, tokens_data, batch_size=32, filename='embedding_data_large.pkl', num_samples=10000,
                     save_every=1000, downsample_factor=25, pad_token_id=0):
        """Collect (x, y) pairs and pickle them as a list of numpy-array tuples.

        x = last-layer hidden state at a sampled token position
        y = expectation of the token embeddings under the model's next-token
            distribution at that position, centred and transformed

        If the output pickle already exists, its pairs are loaded and collection
        resumes at an estimated row.

        Args:
            tokens_data: 2-D array of token ids whose row slices are read with
                ``tokens_data[i:j].read().result()`` (e.g. a tensorstore array).
            batch_size: Number of rows per forward pass.
            filename: Pickle file name inside ``self.output_dir``.
            num_samples: Number of token rows (sequences) to process. Each row
                contributes ``max(1, n_valid // downsample_factor)`` pairs.
            save_every: Checkpoint after this many newly processed rows.
            downsample_factor: Roughly one valid position in this many is sampled.
            pad_token_id: Token id treated as padding in ``tokens_data``.

        Returns:
            Path of the written pickle file.
        """
        final_filename = os.path.join(self.output_dir, filename)
        self.final_filename = final_filename

        # Pairs produced by one fully unpadded row.
        pairs_per_sequence = max(1, tokens_data.shape[1] // downsample_factor)

        if os.path.exists(final_filename):
            with open(final_filename, 'rb') as f:
                data_pairs = pickle.load(f)
            print(f"Loaded {len(data_pairs)} existing data pairs")
            # Only pairs are stored, so the resume row is estimated from the pair
            # count. With unchanged settings, padded rows yield fewer pairs. That
            # makes this an underestimate (some rows are re-processed), never a skip.
            last_save_count = len(data_pairs) // pairs_per_sequence
        else:
            last_save_count = 0
            data_pairs = []
            self._save_pairs(final_filename, data_pairs)
        starting_pos = last_save_count

        # Token embedding matrix (vocab x hidden) on the primary device.
        embedding_matrix = self.model.get_input_embeddings().weight.detach().to(self.device)
        g_mean, whitening_matrix, scale = self._whitening_transform(embedding_matrix)

        # current_count counts rows of tokens_data processed so far, including
        # resumed ones. num_samples and save_every are measured in rows.
        current_count = last_save_count
        stop = min(num_samples, tokens_data.shape[0])

        for i in tqdm(range(starting_pos, stop, batch_size)):
            batch_sequences = tokens_data[i:i + batch_size].read().result()

            # Prepend a BOS column. Build on the CPU and transfer once.
            batch_size_actual, seq_len = batch_sequences.shape
            padded_sequences = torch.empty((batch_size_actual, seq_len + 1), dtype=torch.long)
            padded_sequences[:, 0] = self.tokenizer.bos_token_id
            padded_sequences[:, 1:] = torch.from_numpy(batch_sequences)
            padded_sequences = padded_sequences.to(self.device)

            # NOTE(review): no attention_mask is passed. Under causal attention,
            # trailing padding cannot affect earlier positions. Leading or interior
            # padding, however, would be attended to — confirm the data layout.
            with torch.no_grad():
                outputs = self.model(padded_sequences, output_hidden_states=True)

            last_hidden_states = outputs.hidden_states[-1]  # [batch, seq_len+1, hidden]
            logits = outputs.logits  # [batch, seq_len+1, vocab]

            # Positions holding real tokens, excluding the BOS column.
            mask = (padded_sequences != pad_token_id)[:, 1:]  # [batch, seq_len]

            for b in range(batch_size_actual):
                valid_indices = torch.nonzero(mask[b]).squeeze(-1)
                if valid_indices.numel() > 0:
                    # Shift by one to index the BOS-prefixed sequence, then keep
                    # a random ~1/downsample_factor subset of the valid positions.
                    base_positions = valid_indices + 1
                    num_to_sample = max(1, len(base_positions) // downsample_factor)
                    perm = torch.randperm(len(base_positions), device=base_positions.device)
                    token_positions = base_positions[perm[:num_to_sample]]

                    token_embeddings = last_hidden_states[b, token_positions]
                    probs_batch = torch.softmax(logits[b, token_positions], dim=-1)

                    # Expected embedding under the next-token distribution,
                    # then centred and transformed.
                    expected_embeddings = probs_batch @ embedding_matrix
                    expected_embeddings_causal = (expected_embeddings - g_mean) @ whitening_matrix * scale

                    data_pairs.extend(zip(
                        token_embeddings.to(torch.float32).cpu().numpy(),
                        expected_embeddings_causal.to(torch.float32).cpu().numpy(),
                    ))

                # Count every row, even all-padding ones, so current_count tracks
                # the row index that num_samples and resuming are based on.
                current_count += 1
                if current_count - last_save_count >= save_every:
                    self._save_pairs(final_filename, data_pairs)
                    print(f"Saved progress: {current_count} sequences, {len(data_pairs)} data pairs")
                    last_save_count = current_count

                if current_count >= num_samples:
                    return self._finish(final_filename, data_pairs)

        return self._finish(final_filename, data_pairs)

# 2. Dataset and DataLoader
class EmbeddingDataset(Dataset):
    """Map-style dataset over a pickled list of ``(x, y)`` pairs.

    Every pair is held in memory. Both members are converted to fresh float32
    tensors on access.

    Note: ``pickle.load`` can execute arbitrary code. Only load files you
    produced yourself.
    """

    def __init__(self, data_path):
        """Read all pairs from the pickle at ``data_path``."""
        print(f"Loading data from {data_path}")
        with open(data_path, 'rb') as handle:
            pairs = pickle.load(handle)
        self.data = pairs

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return pair ``idx`` as an ``(x, y)`` tuple of float32 tensors."""
        inp, target = self.data[idx]
        return tuple(torch.tensor(part, dtype=torch.float32) for part in (inp, target))

# 3. MLP Model
class MLP(nn.Module):
    """Three-layer GELU MLP whose input is first scaled by ``1 / sqrt(input_dim)``."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Build ``input_dim -> hidden_dim -> hidden_dim -> output_dim`` with GELUs.

        Args:
            input_dim: Size of the last dimension of the input.
            hidden_dim: Width of both hidden layers.
            output_dim: Size of the last dimension of the output.
        """
        super().__init__()
        self.input_dim = input_dim
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Scale ``x`` by ``1 / sqrt(input_dim)`` and apply the layers.

        The input tensor is not modified.
        """
        # Out-of-place division so the caller's tensor is left intact. An
        # in-place ``/=`` would also raise on leaf tensors that require grad.
        x = x / np.sqrt(self.input_dim)
        return self.layers(x)

# 4. Training Function
def train_model(model, train_loader, val_loader, epochs=10, lr=1e-4, save_path="~/dual-map/best_mlp_model.pt"):
    """Train ``model`` with MSE loss and AdamW, checkpointing the best epoch.

    After every epoch the mean validation loss is computed. Whenever it improves
    on the best seen so far, ``model.state_dict()`` is written to ``save_path``.

    Args:
        model: Module mapping a batch of inputs to predictions.
        train_loader: Sized iterable of ``(x, y)`` batches used for optimisation.
        val_loader: Sized iterable of ``(x, y)`` batches used for model selection.
        epochs: Number of passes over ``train_loader``.
        lr: AdamW learning rate.
        save_path: Checkpoint path. A leading ``~`` is expanded and missing
            parent directories are created.

    Returns:
        The model after the final epoch, moved to the selected device. Its
        weights are not necessarily those of the best checkpoint.

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail fast instead of hitting a ZeroDivisionError after a full epoch.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must both contain at least one batch")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # torch.save does not expand "~". Without this, the default path points at a
    # literal "~" directory relative to the working directory.
    save_path = os.path.expanduser(save_path)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Mean of per-batch losses (a short final batch weighs like a full one).
        train_loss /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
                x, y = x.to(device), y.to(device)
                output = model(x)
                val_loss += criterion(output, y).item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

    return model

# Main execution
def main():
    """Open the token store and collect embedding pairs from it.

    The tensorstore path below is a placeholder. Point it at the real zarr3
    array of token ids before running.
    """
    # Token ids: zarr3 array of shape [3921600, 254], int64, on the local filesystem.
    store_spec = {
        'driver': 'zarr3',
        'cache_pool': {'total_bytes_limit': 1E9},
        'recheck_cached_data': 'open',
        'kvstore': {
            'driver': 'file',
            'file_io_concurrency': {'limit': 2048},
            'path': '/your/path/to/tokens',
        },
    }
    layout = ts.ChunkLayout(write_chunk_shape=[10240, 254])
    tokens_data = ts.open(
        store_spec,
        dtype=ts.int64,
        chunk_layout=layout,
        shape=[3921600, 254],
    ).result()

    print("Collect embedding data")
    collector = EmbeddingDataCollector()
    # Adjust num_samples (number of token rows) as needed. Quick smoke test:
    # collector.collect_data(tokens_data, num_samples=1, filename='embedding_data_append_test.pkl')
    collector.collect_data(
        tokens_data,
        num_samples=200000,
        filename='embedding_data_large_random_saved.pkl',
    )


if __name__ == "__main__":
    main()